/*
 * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "prepareCustomMask.h"
#include <cstdint>
#include <cub/cub.cuh>
#include <cuda_runtime.h>

namespace tensorrt_llm
{
namespace kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////

// Divides a by b and rounds the quotient up. Intended for a >= 0 and b > 0; other sign combinations are
// not handled specially and simply follow the truncating integer division below.
__device__ __host__ inline int32_t ceilDiv(int32_t a, int32_t b)
{
    // Bias the numerator so that truncating division rounds up.
    int32_t const biasedNumerator = a + b - 1;
    return biasedNumerator / b;
}

// Expands the packed input mask of each batch entry into the tiled custom-mask bit buffer.
//
// Thread mapping (taken from the index computations below):
//   blockIdx.x                -> batch index
//   blockIdx.y / threadIdx.x  -> Q row, flattened as tokenIdxQ * numHeadsQPerKv + headIdx
//   blockIdx.z / threadIdx.y  -> KV column, counted from the first sparse offset rounded down to a
//                                multiple of mStepKv
//
// Reads:
//   generalPackedCustoMaskPtr   : packed bits indexed as [batchIdx, tokenIdxQ, word] with
//                                 ceilDiv(seqLenQ, 32) words per row (seqLenQ of the same batch entry)
//   customMaskOffsetsPtr        : start of each batch entry inside customMaskPtr, in uint32_t words
//   firstSparseMaskOffsetsKvPtr : first KV position that is governed by the input mask
//   seqlensQPtr, seqLensKvPtr   : per-batch sequence lengths
// Writes:
//   customMaskPtr : per batch entry, bits laid out as
//                   [numTilesQ, numTilesKv, numInstsQ, numInstsKv, tileSizeQ, tileSizeKv].
//                   Bits are only ever set (atomicOr), never cleared.
// NOTE(review): presumably customMaskPtr is zero-filled before this kernel runs - confirm with the caller.
__global__ void prepareCustomMaskBuffersKernelForKeepsMmaAb(
    TllmGenFmhaRunnerParams runnerParams, TllmGenFmhaKernelMetaInfo kernelMeta)
{
    int32_t const numBatches = runnerParams.mBatchSize;
    int32_t const batch = static_cast<int32_t>(blockIdx.x);
    if (batch >= numBatches)
    {
        return;
    }

    // Per-CTA steps and the instruction tiles each step is divided into.
    int32_t const stepQ = kernelMeta.mStepQ;
    int32_t const stepKv = kernelMeta.mStepKv;
    int32_t const instSizeQ = kernelMeta.mTileSizeQ;
    int32_t const instSizeKv = kernelMeta.mTileSizeKv;
    int32_t const instsPerStepQ = kernelMeta.mStepQ / kernelMeta.mTileSizeQ;
    int32_t const instsPerStepKv = kernelMeta.mStepKv / kernelMeta.mTileSizeKv;

    int32_t const headsPerKv = runnerParams.mNumHeadsQPerKv;
    int32_t const* lensKv = runnerParams.seqLensKvPtr;
    int64_t* maskOffsets = runnerParams.customMaskOffsetsPtr;
    uint32_t* maskOut = runnerParams.customMaskPtr;
    int32_t const* packedIn = runnerParams.generalPackedCustoMaskPtr;
    int32_t* sparseStartsKv = runnerParams.firstSparseMaskOffsetsKvPtr;

    // First KV position covered by the input mask, and the start of the KV step that contains it.
    int32_t const sparseStartKv = sparseStartsKv[batch];
    int32_t const sparseStartTileKv = sparseStartKv / stepKv;
    int32_t const alignedStartKv = sparseStartTileKv * stepKv;

    int32_t const lenQ = runnerParams.seqlensQPtr[batch];
    int32_t const lenKv = lensKv[batch];

    // Q row of this thread, flattened across the heads that share one KV head.
    int32_t const rowQ = static_cast<int32_t>(blockIdx.y) * static_cast<int32_t>(blockDim.x)
        + static_cast<int32_t>(threadIdx.x);
    // KV column of this thread relative to alignedStartKv, and the absolute KV position it stands for.
    int32_t const colKv = static_cast<int32_t>(blockIdx.z) * static_cast<int32_t>(blockDim.y)
        + static_cast<int32_t>(threadIdx.y);
    int32_t const posKv = alignedStartKv + colKv;

    // Threads outside the valid Q rows or past the end of the KV sequence have nothing to write.
    if (rowQ >= lenQ * headsPerKv || posKv >= lenKv)
    {
        return;
    }

    // All heads of a token share the same row of the input mask.
    int32_t const tokenQ = rowQ / headsPerKv;

    bool attend = false;
    if (posKv < sparseStartKv)
    {
        // Positions between the aligned start and the first sparse offset are always attended.
        attend = true;
    }
    else
    {
        // Position inside the input mask row; bits beyond seqLenQ are treated as not attended.
        int32_t const posInRow = posKv - sparseStartKv;
        if (posInRow < lenQ)
        {
            int32_t const rowBase = (batch * lenQ + tokenQ) * ceilDiv(lenQ, 32);
            int32_t const wordIdx = rowBase + (posInRow >> 5);
            int32_t const bitIdx = posInRow & 0x1F;
            attend = ((packedIn[wordIdx] >> bitIdx) & 1) != 0;
        }
    }

    if (!attend)
    {
        return;
    }

    // Number of KV steps stored for this batch entry (steps before the first sparse step are skipped).
    int32_t const numTilesKv = ceilDiv(lenKv, stepKv) - sparseStartTileKv;
    uint32_t* batchMask = maskOut + maskOffsets[batch];

    // Decompose the Q row into (step, instruction within step, row within instruction).
    // rowQ already equals tokenQ * headsPerKv + headIdx, so it is used directly.
    int32_t const rowInStepQ = rowQ % stepQ;
    int32_t const tileQ = rowQ / stepQ;
    int32_t const instQ = rowInStepQ / instSizeQ;
    int32_t const rowInInstQ = rowInStepQ % instSizeQ;

    // Same decomposition for the KV column; colKv is posKv - alignedStartKv.
    int32_t const colInStepKv = colKv % stepKv;
    int32_t const tileKv = colKv / stepKv;
    int32_t const instKv = colInStepKv / instSizeKv;
    int32_t const colInInstKv = colInStepKv % instSizeKv;

    // Linearize from the outermost to the innermost dimension using 64-bit arithmetic.
    int64_t const tileLinear = static_cast<int64_t>(tileQ) * numTilesKv + tileKv;
    int64_t const instLinear = tileLinear * instsPerStepQ * instsPerStepKv + (instQ * instsPerStepKv + instKv);
    int64_t const bitLinear = instLinear * instSizeQ * instSizeKv + (rowInInstQ * instSizeKv + colInInstKv);

    // Split the bit index into a uint32_t word and a bit inside it, then set that bit.
    int64_t const outWord = bitLinear >> 5;
    int32_t const outBit = bitLinear & 0x1F;
    atomicOr(&batchMask[outWord], (1U << outBit));
}

// Computes the start offset of every batch entry inside the custom mask buffer.
//
// Each thread handles one batch entry and derives the size of that entry's mask in uint32_t words
// (numTilesQ * numCustomMaskTilesKv * mStepQ * mStepKv bits, divided by 32). Offsets inside a block are the
// exclusive prefix sum of those sizes; the block's base is reserved with one atomicAdd on globalCounter.
// With more than one block the bases follow the order in which the atomics execute, so the regions are
// disjoint but not necessarily ordered by batch index.
//
// Launch requirements: 1D blocks of exactly 128 threads (the BlockScan specialization below), a grid that
// covers mBatchSize threads, and *globalCounter holding the first offset to hand out (the launcher in this
// file zeroes it).
__global__ void computeCustomMaskOffsetsKernel(
    TllmGenFmhaKernelMetaInfo kernelMeta, TllmGenFmhaRunnerParams runnerParams, unsigned long long* globalCounter)
{
    int32_t const batchSize = runnerParams.mBatchSize;
    int32_t const numHeadsQPerKv = runnerParams.mNumHeadsQPerKv;
    int32_t const tileSizeQPerCta = kernelMeta.mStepQ;
    int32_t const tileSizeKvPerCta = kernelMeta.mStepKv;
    int32_t const* seqLensKvPtr = runnerParams.seqLensKvPtr;
    int32_t const* firstSparseMaskOffsetsKvPtr = runnerParams.firstSparseMaskOffsetsKvPtr;

    // The block size template argument must equal blockDim.x used at launch.
    typedef cub::BlockScan<int64_t, 128> BlockScan;
    __shared__ typename BlockScan::TempStorage temp_storage;

    int const idx = blockIdx.x * blockDim.x + threadIdx.x;
    // Threads past the end of the batch contribute zero so that every thread can take part in the scan.
    int64_t maskSize = 0;

    if (idx < batchSize)
    {
        int32_t const seqLenQ = runnerParams.seqlensQPtr[idx];
        int32_t const seqLenKv = seqLensKvPtr[idx];
        int32_t const firstSparseMaskOffsetKv = firstSparseMaskOffsetsKvPtr[idx];

        int32_t const numTilesQ = ceilDiv(seqLenQ * numHeadsQPerKv, tileSizeQPerCta);
        // KV steps that lie entirely before the first sparse position are not stored.
        int32_t const firstSparseTile = firstSparseMaskOffsetKv / tileSizeKvPerCta;
        int32_t const numCustomMaskTilesKv = ceilDiv(seqLenKv, tileSizeKvPerCta) - firstSparseTile;

        // Widen before multiplying: the product of four 32-bit factors can exceed the int32_t range, and a
        // cast applied to the finished product would come too late to prevent the overflow.
        maskSize = static_cast<int64_t>(numTilesQ) * numCustomMaskTilesKv * tileSizeQPerCta * tileSizeKvPerCta / 32;
    }

    int64_t prefixOffset;
    int64_t blockSum;
    BlockScan(temp_storage).ExclusiveSum(maskSize, prefixOffset, blockSum);

    // One thread reserves the block's range in the global counter and shares the base with the block.
    __shared__ unsigned long long blockBase;
    if (threadIdx.x == 0)
    {
        blockBase = atomicAdd(globalCounter, (unsigned long long) blockSum);
    }
    __syncthreads();

    if (idx < batchSize)
    {
        runnerParams.customMaskOffsetsPtr[idx] = static_cast<int64_t>(blockBase) + prefixOffset;
    }
}

// Enqueues computeCustomMaskOffsetsKernel on `stream` to fill runnerParams.customMaskOffsetsPtr.
//
// A scratch counter is allocated, zeroed and released with the stream-ordered allocator, so everything
// stays ordered on `stream` and no host synchronization is performed. CUDA API and launch errors are
// reported through TLLM_CUDA_CHECK. An empty batch is a no-op.
void launchComputeCustomMaskOffsetsKernel(
    TllmGenFmhaKernelMetaInfo const& kernelMeta, TllmGenFmhaRunnerParams const& runnerParams, cudaStream_t stream)
{
    int32_t const batchSize = runnerParams.mBatchSize;
    if (batchSize <= 0)
    {
        // Nothing to compute; a zero-sized grid would be rejected as an invalid launch configuration.
        return;
    }

    unsigned long long* d_globalCounter = nullptr;
    TLLM_CUDA_CHECK(cudaMallocAsync(&d_globalCounter, sizeof(unsigned long long), stream));

    // From here on the counter must be released even if a later step fails, so collect the status first.
    cudaError_t status = cudaMemsetAsync(d_globalCounter, 0, sizeof(unsigned long long), stream);
    if (status == cudaSuccess)
    {
        // Must stay in sync with the cub::BlockScan block size used inside the kernel.
        int const blockSize = 128;
        int const gridSize = (batchSize + blockSize - 1) / blockSize;
        computeCustomMaskOffsetsKernel<<<gridSize, blockSize, 0, stream>>>(kernelMeta, runnerParams, d_globalCounter);
        status = cudaGetLastError();
    }

    // Stream-ordered free: the memory is returned only after the work enqueued above has run.
    cudaError_t const freeStatus = cudaFreeAsync(d_globalCounter, stream);

    TLLM_CUDA_CHECK(status);
    TLLM_CUDA_CHECK(freeStatus);
}

// Post-processing kernel: rounds every firstSparseMaskOffsetsKv entry down to a multiple of
// kernelMeta.mStepKv, in place. One thread per batch entry (1D grid, 1D blocks); surplus threads do nothing.
__global__ void adjustFirstSparseMaskOffsetsKernel(
    TllmGenFmhaRunnerParams runnerParams, TllmGenFmhaKernelMetaInfo kernelMeta)
{
    int32_t const numEntries = runnerParams.mBatchSize;
    int32_t const stepKv = kernelMeta.mStepKv;
    int32_t const entry = blockIdx.x * blockDim.x + threadIdx.x;
    if (entry < numEntries)
    {
        int32_t* offsetsKv = runnerParams.firstSparseMaskOffsetsKvPtr;
        int32_t const rawOffset = offsetsKv[entry];
        // Truncating division followed by multiplication snaps the offset to the start of its KV step.
        int32_t const stepIdx = rawOffset / stepKv;
        offsetsKv[entry] = stepIdx * stepKv;
    }
}

// Enqueues the custom mask preparation on `stream`:
//   1. prepareCustomMaskBuffersKernelForKeepsMmaAb sets the mask bits of every batch entry,
//   2. adjustFirstSparseMaskOffsetsKernel rounds firstSparseMaskOffsetsKv down to a multiple of mStepKv.
// The second kernel overwrites values the first one reads, so it relies on stream order to run afterwards.
// Each launch is checked separately so that a failure is attributed to the right kernel. An empty batch is a
// no-op. No host synchronization is performed.
void launchPrepareCustomMaskBuffersKernelForKeepsMmaAb(
    TllmGenFmhaRunnerParams const& runnerParams, TllmGenFmhaKernelMetaInfo const& kernelMeta, cudaStream_t stream)
{
    int32_t const batchSize = runnerParams.mBatchSize;
    if (batchSize <= 0)
    {
        // Nothing to prepare; a grid with a zero dimension would be rejected as an invalid configuration.
        return;
    }

    int32_t const maxSeqLenQ = runnerParams.mMaxSeqLenQ;
    int32_t const numHeadsQPerKv = runnerParams.mNumHeadsQPerKv;
    int32_t const tileSizeKvPerCta = kernelMeta.mStepKv;

    // Total Q tokens (flattened across heads)
    int32_t const maxTotalQTokens = maxSeqLenQ * numHeadsQPerKv;

    // Calculate the maximum KV range to process
    // The actual range is [adjustedFirstSparseMaskOffsetKv, seqLenKv)
    // adjustedFirstSparseMaskOffsetKv <= firstSparseMaskOffsetKv = seqLenKv - seqLenQ
    // So the maximum range length is: seqLenKv - adjustedFirstSparseMaskOffsetKv <= maxSeqLenQ + (tileSizeKvPerCta - 1)
    int32_t const maxKvRangeLength = maxSeqLenQ + (tileSizeKvPerCta - 1);

    // Block shape: x covers flattened Q rows, y covers KV columns.
    int32_t const qTokensPerBlock = 64;
    int32_t const kvTokensPerBlock = 4;

    int32_t const numBlocksY = ceilDiv(maxTotalQTokens, qTokensPerBlock);
    int32_t const numBlocksZ = ceilDiv(maxKvRangeLength, kvTokensPerBlock);

    // Named grid/block rather than gridDim/blockDim to avoid shadowing the CUDA built-in variables.
    dim3 const grid(batchSize, numBlocksY, numBlocksZ);
    dim3 const block(qTokensPerBlock, kvTokensPerBlock, 1);

    prepareCustomMaskBuffersKernelForKeepsMmaAb<<<grid, block, 0, stream>>>(runnerParams, kernelMeta);
    TLLM_CUDA_CHECK(cudaGetLastError());

    // Ensure adjusted firstSparse offsets are written only after all blocks finish
    {
        int const blockSize = 128;
        int const gridSize = (batchSize + blockSize - 1) / blockSize;
        adjustFirstSparseMaskOffsetsKernel<<<gridSize, blockSize, 0, stream>>>(runnerParams, kernelMeta);
        TLLM_CUDA_CHECK(cudaGetLastError());
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Host entry point for custom mask preparation.
//
// For kernel types accepted by isKeepsMmaAbForGenerationKernel this enqueues, on `stream`, the offset
// computation followed by the mask fill. Two conditions are reported with TLLM_LOG_ERROR and make the
// function return without launching anything:
//   - mStepQ * mStepKv exceeds 128 * 128 * 2,
//   - the kernel type in kernelMeta is not a keepsMmaAb generation kernel.
void runPrepareCustomMask(
    TllmGenFmhaKernelMetaInfo const& kernelMeta, TllmGenFmhaRunnerParams const& runnerParams, cudaStream_t stream)
{
    if (isKeepsMmaAbForGenerationKernel(static_cast<FmhaKernelType>(kernelMeta.mKernelType)))
    {
        int cta_tile_size = kernelMeta.mStepQ * kernelMeta.mStepKv;
        if (cta_tile_size > 128 * 128 * 2)
        {
            TLLM_LOG_ERROR(
                "TRTLLM-GEN needs larger buffer for  custom mask preparation please enlarge it according to the "
                "formula: tile_size_q * tile_size_k * num_instances_q * num_instances_k");
            return;
        }
        // Step 1: Compute offsets on GPU using prefix sum
        launchComputeCustomMaskOffsetsKernel(kernelMeta, runnerParams, stream);
        // Step 2: Compute custom mask buffers
        launchPrepareCustomMaskBuffersKernelForKeepsMmaAb(runnerParams, kernelMeta, stream);
        TLLM_CUDA_CHECK(cudaGetLastError());
    }
    else
    {
        // Report the value the branch above was decided on (kernelMeta), cast to int to match %d.
        TLLM_LOG_ERROR("TRTLLM-GEN does not support kernel type: %d for custom mask preparation",
            static_cast<int>(kernelMeta.mKernelType));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernels
} // namespace tensorrt_llm
